Day 2 pm: Bambi

In this session we will learn how to use the Python library bambi to fit formula-based Bayesian regression models.

Imports

First we import all the Python packages that we will need.

import arviz as az
import bambi as bmb
import polars as pl
import numpy as np
from matplotlib import pyplot as plt

Hello world example

To demonstrate how bambi works we’ll start with a very simple linear regression model, where the variate \(y\) is an unconstrained real number that is predicted by a single covariate \(x\).

We can simulate some measurements from this model like this:

A_HW = 0.2
B_HW = 1.7
SIGMA_HW = 0.5

def simulate_hw(x: float, a: float, b: float, sigma: float) -> float:
    "Simulate a measurement given covariate x, intercept a, weight b and error sd sigma."
    yhat = a + x * b
    return yhat + np.random.normal(0, scale=sigma)

x_hw = np.linspace(-0.5, 1.5, 10)
y_hw = np.array([simulate_hw(x_i, A_HW, B_HW, SIGMA_HW) for x_i in x_hw])
data_hw = pl.DataFrame({"x": x_hw, "y": y_hw})

f, ax = plt.subplots()
ax.scatter(x_hw, y_hw, marker="x", color="black", label="simulated observation")
ax.set(xlabel="x", ylabel="y")
plt.show()

We can implement this model using bambi with the very simple formula `y ~ x`.

formula_hw = "y ~ x"
model_hw = bmb.Model(formula_hw, data=data_hw.to_pandas())
model_hw
       Formula: y ~ x
        Family: gaussian
          Link: mu = identity
  Observations: 10
        Priors: 
    target = mu
        Common-level effects
            Intercept ~ Normal(mu: 1.1428, sigma: 3.6137)
            x ~ Normal(mu: 0.0, sigma: 4.457)
        
        Auxiliary parameters
            sigma ~ HalfStudentT(nu: 4.0, sigma: 1.1379)

To perform Bayesian inference, we use the model object’s fit method:

results_hw = model_hw.fit()
az.summary(results_hw)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, Intercept, x]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
sigma      0.566  0.164   0.322    0.875      0.004    0.004    2361.0    2515.0    1.0
Intercept  0.328  0.237  -0.118    0.771      0.004    0.005    3133.0    2338.0    1.0
x          1.637  0.282   1.088    2.154      0.005    0.005    2999.0    2618.0    1.0
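As a quick sanity check on these numbers, the posterior means for `Intercept` and `x` should land close to the ordinary least-squares estimates, since the default priors are weakly informative. Here is a numpy-only sketch of that check; it re-simulates comparable data with a fixed seed (an assumption, since the original simulation was not seeded), so the exact numbers will differ from the run above.

```python
import numpy as np

# Re-simulate data like the hello-world example, but seeded for
# reproducibility, and compare the OLS fit to the true parameters.
rng = np.random.default_rng(1234)
x = np.linspace(-0.5, 1.5, 10)
y = 0.2 + 1.7 * x + rng.normal(0, 0.5, size=x.size)

# np.polyfit with deg=1 returns [slope, intercept]
slope, intercept = np.polyfit(x, y, deg=1)
print(f"OLS intercept: {intercept:.3f}, OLS slope: {slope:.3f}")
```

With only 10 noisy observations the estimates are rough, but they should sit within a couple of standard errors of the true values, just like the posterior means above.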

More relevant example

As a second example, we will fit the model introduced in yesterday’s session.

Suppose we have five strains that we want to evaluate for their ability to ferment a protein. To test these abilities we perform 4 biological replicates per strain, each of which we test 5 times.

TRUE_PRODUCTIVITY = {
    "a": 0.49,
    "b": 0.51,
    "c": 0.53,
    "d": 0.55,
    "e": 0.57
}
N_BIOLOGICAL_REPLICATE = 4
N_TECHNICAL_REPLICATE = 5
BIOLOGICAL_VARIATION = 0.1
TECHNICAL_VARIATION = 0.01


def simulate_fermentation(prod, bio_effect, tv):
    return np.exp(np.log(prod) + bio_effect + np.random.normal(0, scale=tv))

rows = []
for strain, prod in TRUE_PRODUCTIVITY.items():
    for row_br in range(N_BIOLOGICAL_REPLICATE):
        bio_effect = np.random.normal(0, BIOLOGICAL_VARIATION)
        for row_tr in range(N_TECHNICAL_REPLICATE):
            rows.append(
                {
                  "strain": strain,
                  "biological_replicate": f"{strain}-{row_br}",
                  "technical_replicate": f"{strain}-{row_br}-{row_tr}",
                  "y": simulate_fermentation(
                      prod,
                      bio_effect,
                      TECHNICAL_VARIATION,
                  ),
                }
            )
data_bio = pl.from_records(rows).with_columns(log_y=np.log(pl.col("y")))
data_bio.head()
shape: (5, 5)
strain  biological_replicate  technical_replicate  y         log_y
str     str                   str                  f64       f64
"a"     "a-0"                 "a-0-0"              0.478268  -0.737583
"a"     "a-0"                 "a-0-1"              0.476265  -0.741782
"a"     "a-0"                 "a-0-2"              0.484644  -0.724341
"a"     "a-0"                 "a-0-3"              0.47724   -0.739737
"a"     "a-0"                 "a-0-4"              0.480249  -0.73345

To specify the model we do more or less the same as before, except that this time our formula is "log_y ~ 0 + strain + (1|biological_replicate)". This indicates a model with no global intercept (that is what the 0 at the start of the right hand side does), a separate intercept parameter per strain, and per-biological-replicate intercepts that are modelled hierarchically.
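To build intuition for what the `0 + strain` term expands to, here is a rough sketch using pandas one-hot encoding. Bambi builds its design matrices with the formulae library, so this is an illustration of the idea rather than its actual code path:

```python
import pandas as pd

# With "0 + strain", each strain gets its own indicator column instead
# of a shared intercept plus offsets, so each strain coefficient is that
# strain's own intercept.
df = pd.DataFrame({"strain": ["a", "a", "b", "c"]})
X = pd.get_dummies(df["strain"], prefix="strain", dtype=float)
print(X)
```

Each row of `X` has exactly one 1.0, in the column for that observation's strain.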

Since our model has slightly unusual scales, we also supply some custom priors. Note the nested structure for the "1|biological_replicate" prior.

formula_bio = "log_y ~ 0 + strain + (1|biological_replicate)"
bio_var_prior = bmb.Prior("HalfNormal", sigma=0.2)
br_effect_prior = bmb.Prior("Normal", mu=0.0, sigma=bio_var_prior)
priors = {
    "strain": bmb.Prior("Normal", mu=-0.7, sigma=0.3),
    "sigma": bmb.Prior("HalfNormal", sigma=0.01),
    "1|biological_replicate": br_effect_prior,
}
model_bio = bmb.Model(formula_bio, data=data_bio.to_pandas(), priors=priors)
model_bio
       Formula: log_y ~ 0 + strain + (1|biological_replicate)
        Family: gaussian
          Link: mu = identity
  Observations: 100
        Priors: 
    target = mu
        Common-level effects
            strain ~ Normal(mu: -0.7, sigma: 0.3)
        
        Group-level effects
            1|biological_replicate ~ Normal(mu: 0.0, sigma: HalfNormal(sigma: 0.2))
        
        Auxiliary parameters
            sigma ~ HalfNormal(sigma: 0.01)

Fitting and inspecting work the same as before, but to save space we avoid printing the 1|biological_replicate parameters. This is a handy arviz trick!

results_bio = model_bio.fit()
az.summary(
    results_bio,
    var_names="~1|biological_replicate",
    filter_vars="regex"
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, strain, 1|biological_replicate_sigma, 1|biological_replicate_offset]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds.
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
sigma      0.010  0.001   0.009    0.012      0.000    0.000    2902.0    2239.0    1.0
strain[a] -0.736  0.038  -0.812   -0.667      0.001    0.001    1925.0    1821.0    1.0
strain[b] -0.650  0.038  -0.722   -0.580      0.001    0.001    1767.0    1913.0    1.0
strain[c] -0.693  0.038  -0.768   -0.625      0.001    0.001    1960.0    2008.0    1.0
strain[d] -0.590  0.040  -0.662   -0.512      0.001    0.001    2252.0    1892.0    1.0
strain[e] -0.564  0.039  -0.639   -0.492      0.001    0.001    1708.0    1477.0    1.0

Now we can check that the exponentiated strain intercepts roughly match the true productivities we used in the simulation.

prod_mean = np.exp(results_bio.posterior["strain"]).mean(dim=["chain", "draw"])
pl.DataFrame(
    {
        "strain": TRUE_PRODUCTIVITY.keys(),
        "true_productivity": TRUE_PRODUCTIVITY.values(),
        "posterior_mean": prod_mean.values
    }
)
shape: (5, 3)
strain  true_productivity  posterior_mean
str     f64                f64
"a"     0.49               0.479508
"b"     0.51               0.522182
"c"     0.53               0.500474
"d"     0.55               0.554984
"e"     0.57               0.569367
Exercise

Try to find the probability, according to our model, that strain “a” is less productive than strain “c”.
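Hint: one way to approach this is to compare the posterior draws for the two strains pairwise and take the fraction of draws where "a" is below "c". The sketch below uses synthetic stand-in draws (the names `draws_a` and `draws_c` and the normal approximations are hypothetical); with the real fit you would instead pull the matching slices out of `results_bio.posterior["strain"]` and compare them draw by draw, which also respects any posterior correlation between the two parameters.

```python
import numpy as np

# Stand-in draws, roughly matching the posterior summaries above.
# With the real posterior, use the actual sample arrays instead.
rng = np.random.default_rng(0)
draws_a = rng.normal(-0.736, 0.038, size=4000)  # stand-in for strain[a]
draws_c = rng.normal(-0.693, 0.038, size=4000)  # stand-in for strain[c]

# P(a less productive than c) = fraction of joint draws with a < c
p_a_less_than_c = (draws_a < draws_c).mean()
print(p_a_less_than_c)
```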